{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# High-level MXNet Example\n",
    "\n",
    "**In the interest of comparison; a common (custom) data-generator (called yield_mb(X, y, batchsize=64, shuffle=False)) was originally used for all other frameworks - but not for MXNet. I have reproduced the MXNet example using this same generator (wrapping the results in the mx.io.DataBatch class) to test if MXNet is faster than other frameworks just because I was using its own data-generator. This does not appear to be the case. **"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "import mxnet as mx\n",
    "from common.params import *\n",
    "from common.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OS:  linux\n",
      "Python:  3.5.2 |Anaconda custom (64-bit)| (default, Jul  2 2016, 17:53:06) \n",
      "[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]\n",
      "Numpy:  1.13.3\n",
      "MXNet:  0.12.1\n",
      "GPU:  ['Tesla K80']\n"
     ]
    }
   ],
   "source": [
    "print(\"OS: \", sys.platform)\n",
    "print(\"Python: \", sys.version)\n",
    "print(\"Numpy: \", np.__version__)\n",
    "print(\"MXNet: \", mx.__version__)\n",
    "print(\"GPU: \", get_gpu_name())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_symbol():\n",
    "    data = mx.symbol.Variable('data')\n",
    "    # size = [(old-size - kernel + 2*padding)/stride]+1\n",
    "    # if kernel = 3, pad with 1 either side\n",
    "    conv1 = mx.symbol.Convolution(data=data, num_filter=50, pad=(1,1), kernel=(3,3))\n",
    "    relu1 = mx.symbol.Activation(data=conv1, act_type=\"relu\")\n",
    "    conv2 = mx.symbol.Convolution(data=relu1, num_filter=50, pad=(1,1), kernel=(3,3))\n",
    "    pool1 = mx.symbol.Pooling(data=conv2, pool_type=\"max\", kernel=(2,2), stride=(2,2))\n",
    "    relu2 = mx.symbol.Activation(data=pool1, act_type=\"relu\")\n",
    "    drop1 = mx.symbol.Dropout(data=relu2, p=0.25)\n",
    "    \n",
    "    conv3 = mx.symbol.Convolution(data=drop1, num_filter=100, pad=(1,1), kernel=(3,3))\n",
    "    relu3 = mx.symbol.Activation(data=conv3, act_type=\"relu\")\n",
    "    conv4 = mx.symbol.Convolution(data=relu3, num_filter=100, pad=(1,1), kernel=(3,3))\n",
    "    pool2 = mx.symbol.Pooling(data=conv4, pool_type=\"max\", kernel=(2,2), stride=(2,2))\n",
    "    relu4 = mx.symbol.Activation(data=pool2, act_type=\"relu\")\n",
    "    drop2 = mx.symbol.Dropout(data=relu4, p=0.25)\n",
    "           \n",
    "    flat1 = mx.symbol.Flatten(data=drop2)\n",
    "    fc1 = mx.symbol.FullyConnected(data=flat1, num_hidden=512)\n",
    "    relu7 = mx.symbol.Activation(data=fc1, act_type=\"relu\")\n",
    "    drop4 = mx.symbol.Dropout(data=relu7, p=0.5)\n",
    "    fc2 = mx.symbol.FullyConnected(data=drop4, num_hidden=N_CLASSES) \n",
    "    \n",
    "    input_y = mx.symbol.Variable('softmax_label')  \n",
    "    m = mx.symbol.SoftmaxOutput(data=fc2, label=input_y, name=\"softmax\")\n",
    "    return m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_model(m):\n",
    "    if GPU:\n",
    "        ctx = [mx.gpu(0)]\n",
    "    else:\n",
    "        ctx = mx.cpu()\n",
    "    \n",
    "    mod = mx.mod.Module(context=ctx, symbol=m)\n",
    "    mod.bind(data_shapes=[('data', (BATCHSIZE, 3, 32, 32))],\n",
    "             label_shapes=[('softmax_label', (BATCHSIZE,))])\n",
    "\n",
    "    # Glorot-uniform initializer\n",
    "    mod.init_params(initializer=mx.init.Xavier(rnd_type='uniform'))\n",
    "    mod.init_optimizer(optimizer='sgd', \n",
    "                       optimizer_params=(('learning_rate', LR), ('momentum', MOMENTUM), ))\n",
    "    return mod"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preparing train set...\n",
      "Preparing test set...\n",
      "(50000, 3, 32, 32) (10000, 3, 32, 32) (50000,) (10000,)\n",
      "float32 float32 int32 int32\n",
      "CPU times: user 834 ms, sys: 550 ms, total: 1.38 s\n",
      "Wall time: 1.38 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Data into format for library\n",
    "x_train, x_test, y_train, y_test = cifar_for_library(channel_first=True)\n",
    "\n",
    "# Load data-iterator\n",
    "#train_iter = mx.io.NDArrayIter(x_train, y_train, BATCHSIZE, shuffle=True)\n",
    "# Use custom iterator instead of mx.io.NDArrayIter() for consistency\n",
    "# Wrap as DataBatch class\n",
    "wrapper_db = lambda args: mx.io.DataBatch(data=[mx.nd.array(args[0])], label=[mx.nd.array(args[1])])\n",
    "\n",
    "print(x_train.shape, x_test.shape, y_train.shape, y_test.shape)\n",
    "print(x_train.dtype, x_test.dtype, y_train.dtype, y_test.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.47 ms, sys: 687 µs, total: 2.16 ms\n",
      "Wall time: 1.73 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Load symbol\n",
    "sym = create_symbol()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 739 ms, sys: 778 ms, total: 1.52 s\n",
      "Wall time: 1.96 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Initialise model\n",
    "model = init_model(sym)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0, Training ('accuracy', 0.34126920614596673)\n",
      "Epoch 1, Training ('accuracy', 0.49851952624839951)\n",
      "Epoch 2, Training ('accuracy', 0.57626440460947503)\n",
      "Epoch 3, Training ('accuracy', 0.63538332266325226)\n",
      "Epoch 4, Training ('accuracy', 0.6732754481434059)\n",
      "Epoch 5, Training ('accuracy', 0.71028729193341866)\n",
      "Epoch 6, Training ('accuracy', 0.73801616517285529)\n",
      "Epoch 7, Training ('accuracy', 0.75780249679897571)\n",
      "Epoch 8, Training ('accuracy', 0.77704865556978231)\n",
      "Epoch 9, Training ('accuracy', 0.79661491677336749)\n",
      "CPU times: user 2min 16s, sys: 27.4 s, total: 2min 43s\n",
      "Wall time: 2min 25s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# 145s\n",
    "# Train and log accuracy\n",
    "metric = mx.metric.create('acc')\n",
    "for j in range(EPOCHS):\n",
    "    #train_iter.reset()\n",
    "    metric.reset()\n",
    "    #for batch in train_iter:\n",
    "    for batch in map(wrapper_db, yield_mb(x_train, y_train, BATCHSIZE, shuffle=True)):\n",
    "        model.forward(batch, is_train=True) \n",
    "        model.update_metric(metric, batch.label)\n",
    "        model.backward()              \n",
    "        model.update()\n",
    "    print('Epoch %d, Training %s' % (j, metric.get()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.21 s, sys: 269 ms, total: 1.48 s\n",
      "Wall time: 1.09 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "y_guess = model.predict(mx.io.NDArrayIter(x_test, batch_size=BATCHSIZE, shuffle=False))\n",
    "y_guess = np.argmax(y_guess.asnumpy(), axis=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy:  0.7742\n"
     ]
    }
   ],
   "source": [
    "print(\"Accuracy: \", sum(y_guess == y_test)/len(y_guess))"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
